"""Variational AutoEncoder Implementation"""
from functools import partial
from typing import Any, Tuple, List, Dict, Union, Type, Optional, Callable

import jax
import jax.numpy as jnp
from jax import nn
import haiku as hk
import numpy as np

from diffgro.common.models.utils import (
    act2fn,
    init_he_uniform,
    init_he_normal,
)


# ============================== VAE ============================= #

class Encoder(hk.Module):
    """Base class for encoders that output a Gaussian ``(mean, std)`` pair."""

    def __init__(self):
        super().__init__()

    def sample(self, mean: jax.Array, std: jax.Array, deterministic: bool):
        """Return a latent sample built from ``mean`` and ``std``.

        Args:
            mean: Mean of the latent distribution.
            std: Standard deviation; its shape determines the noise shape.
            deterministic: When truthy, no noise is drawn and ``mean`` is
                returned unchanged.

        Returns:
            ``mean`` if ``deterministic``, else ``mean + std * eps`` where
            ``eps`` is standard-normal noise keyed by ``hk.next_rng_key()``.
        """
        if not deterministic:
            # Reparameterised draw: the noise is independent of mean/std.
            noise = jax.random.normal(hk.next_rng_key(), std.shape)
            return mean + std * noise
        return mean


class MLPEncoder(Encoder):
    """MLP encoder producing the mean and std of a diagonal Gaussian latent.

    Every entry of the input dict named in ``batch_keys`` is embedded by its
    own two-layer MLP, the embeddings are concatenated along the last axis,
    passed through a trunk of ``net_arch`` hidden layers, and finally
    projected to a mean and a standard deviation of width ``ctx_dim``.
    """

    def __init__(
        self,
        net_arch: List[int],
        emb_dim: int, # input embedding dimension
        ctx_dim: int, # context dimension
        batch_keys: List[str],
        activation_fn: str = 'mish',
    ):
        """Store the architecture hyper-parameters.

        Args:
            net_arch: Output widths of the trunk's hidden layers, in order.
            emb_dim: Width of each per-key embedding.
            ctx_dim: Width of the latent mean / std.
            batch_keys: Keys of ``batch_dict`` to embed, in concatenation order.
            activation_fn: Key into ``act2fn`` selecting the non-linearity.
        """
        super().__init__()
        self.net_arch = net_arch
        self.emb_dim = emb_dim
        self.ctx_dim = ctx_dim
        self.batch_keys = batch_keys
        self.activation_fn = activation_fn

    def __call__(self, batch_dict: Dict[str, jax.Array]) -> Tuple[jax.Array, jax.Array]:
        """Encode ``batch_dict`` into ``(mean, std)``.

        Args:
            batch_dict: Mapping that contains an array for every key in
                ``batch_keys``.

        Returns:
            ``(mean, std)``, each with trailing dimension ``ctx_dim``. ``std``
            lies in ``[exp(-20), exp(2)]`` because the log-std is clipped.
        """
        activation = act2fn[self.activation_fn]

        # Haiku auto-numbers unnamed modules in creation order, so the order
        # of the hk.Linear constructions below is kept fixed.
        embedded = []
        for key in self.batch_keys:
            h = hk.Linear(self.emb_dim * 2, **init_he_normal())(batch_dict[key])
            h = activation(h)
            h = hk.Linear(self.emb_dim, **init_he_normal())(h)
            embedded.append(h)
        # [..., len(batch_keys) * emb_dim]
        hidden = jnp.concatenate(embedded, axis=-1)

        for width in self.net_arch:
            hidden = hk.Linear(width, **init_he_normal())(hidden)
            hidden = activation(hidden)

        mean = hk.Linear(self.ctx_dim, **init_he_normal())(hidden)
        log_std = hk.Linear(self.ctx_dim, **init_he_normal())(hidden)
        # Positional bounds: the `a_min`/`a_max` keywords are deprecated in
        # newer JAX releases, while the positional form works on all of them.
        log_std = jnp.clip(log_std, -20, 2)
        std = jnp.exp(log_std)
        return mean, std


class LSTMEncoder(Encoder):
    """LSTM encoder producing the mean and std of a diagonal Gaussian latent.

    Every entry of the input dict named in ``batch_keys`` is embedded by its
    own two-layer MLP, the embeddings are concatenated along the last axis,
    an LSTM is unrolled over the time axis, and the output at the final time
    step is projected to a mean and a standard deviation of width ``ctx_dim``.
    """

    def __init__(
        self,
        horizon: int,
        emb_dim: int, # input embedding dimension
        hid_dim: int, # lstm hidden dimension
        ctx_dim: int, # context dimension
        batch_keys: List[str],
        activation_fn: str = 'mish',
    ):
        """Store the architecture hyper-parameters.

        Args:
            horizon: Sequence length. Stored for reference only; the unroll
                length is taken from the input's time axis.
            emb_dim: Width of each per-key embedding.
            hid_dim: LSTM hidden size.
            ctx_dim: Width of the latent mean / std.
            batch_keys: Keys of ``batch_dict`` to embed, in concatenation order.
            activation_fn: Key into ``act2fn`` selecting the non-linearity.
        """
        super().__init__()
        self.horizon = horizon
        self.emb_dim = emb_dim
        self.ctx_dim = ctx_dim
        self.hid_dim = hid_dim
        self.batch_keys = batch_keys
        self.activation_fn = activation_fn

    def __call__(self, batch_dict: Dict[str, jax.Array]) -> Tuple[jax.Array, jax.Array]:
        """Encode a batch of sequences into ``(mean, std)``.

        Args:
            batch_dict: Mapping that contains an array for every key in
                ``batch_keys``. Arrays are assumed to be batch-major,
                ``[batch, horizon, feature]`` -- TODO confirm against callers.

        Returns:
            ``(mean, std)``, each ``[batch, ctx_dim]``. ``std`` lies in
            ``[exp(-20), exp(2)]`` because the log-std is clipped.
        """
        activation = act2fn[self.activation_fn]

        # Haiku auto-numbers unnamed modules in creation order, so the order
        # of the module constructions below is kept fixed.
        embedded = []
        for key in self.batch_keys:
            h = hk.Linear(self.emb_dim * 2, **init_he_normal())(batch_dict[key])
            h = activation(h)
            h = hk.Linear(self.emb_dim, **init_he_normal())(h)
            embedded.append(h)
        # [batch, horizon, len(batch_keys) * emb_dim]
        inp = jnp.concatenate(embedded, axis=-1)

        # LSTM module.
        # The initial state must be sized by the batch dimension, and the
        # unroll must be told the input is batch-major. With the default
        # time_major=True the recurrence would run across the batch axis and
        # carry state from one example into the next.
        core = hk.LSTM(self.hid_dim)
        initial_state = core.initial_state(inp.shape[0])
        out, _ = hk.dynamic_unroll(core, inp, initial_state, time_major=False)
        # retrieve the output at the last time step : [batch, hid_dim]
        out = out[:, -1, :]

        mean = hk.Linear(self.ctx_dim, **init_he_normal())(out)
        log_std = hk.Linear(self.ctx_dim, **init_he_normal())(out)
        # Positional bounds: the `a_min`/`a_max` keywords are deprecated in
        # newer JAX releases, while the positional form works on all of them.
        log_std = jnp.clip(log_std, -20, 2)
        std = jnp.exp(log_std)
        return mean, std


class DiffusionVAE(hk.Module):
    """VAE wrapper that conditions a decoder on a latent context vector.

    The encoder infers a Gaussian over the context ``ctx`` (unless one is
    supplied by the caller); the decoder is then called on ``x_t`` with the
    context added to its input dict.
    """

    def __init__(
        self,
        encoder: hk.Module,
        decoder: hk.Module,
    ):
        """Store the sub-modules.

        Args:
            encoder: Module called as ``encoder(batch_dict['enc'])`` returning
                ``(mean, std)`` and exposing ``sample(mean, std, deterministic)``.
            decoder: Module called as
                ``decoder(x_t, dec_dict, t, denoise, deterministic)`` returning
                ``(logits, info)``.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def __call__(
        self,
        x_t: jax.Array,
        batch_dict: Dict[str, Dict[str, jax.Array]], # encoder dict, decoder dict
        t: jax.Array,
        ctx: Optional[jax.Array] = None,
        denoise: bool = False,
        deterministic: bool = False,
    ) -> Tuple[Tuple[jax.Array], Dict[str, jax.Array]]:
        """Run encoder inference (if needed) followed by decoder inference.

        Args:
            x_t: Decoder input; its leading dimension is the decoder batch size.
            batch_dict: ``{'enc': ..., 'dec': ...}``. ``'enc'`` is required
                when ``ctx`` is None; ``'dec'`` is optional.
            t: Forwarded unchanged to the decoder.
            ctx: Pre-computed context. When given, the encoder is skipped and
                the returned ``mean`` / ``std`` are None.
            denoise: Forwarded unchanged to the decoder.
            deterministic: Use the encoder mean as context instead of sampling;
                also forwarded to the decoder.

        Returns:
            ``((mean, std, logits), info)``. If ``batch_dict`` has no ``'dec'``
            entry the decoder is skipped and ``((mean, std, None), {})`` is
            returned.
        """
        mean, std = None, None
        # 1. encoder inference
        if ctx is None:
            mean, std = self.encoder(batch_dict['enc'])
            ctx = self.encoder.sample(mean, std, deterministic)

        if 'dec' not in batch_dict:
            return (mean, std, None), {}

        # Repeat ctx so its leading dimension matches x_t. Integer division
        # keeps the repeat count exact instead of going through a float.
        repeats = x_t.shape[0] // ctx.shape[0]
        ctx = jnp.repeat(ctx, repeats, axis=0)

        # 2. decoder inference
        # Build a fresh dict rather than updating the caller's in place.
        dec_batch = {**batch_dict['dec'], 'ctx': ctx}
        logits, info = self.decoder(x_t, dec_batch, t, denoise, deterministic)
        return (mean, std, logits), info
        
# =================================== VQVAE ================================ #

class VectorQuantizer(hk.Module):
    """Vector-quantisation layer with an EMA-updated codebook.

    The codebook is kept in Haiku state (not parameters) under the name
    ``"embeddings"`` with shape ``[emb_dim, emb_num]``. When called with
    ``is_training=True`` it is overwritten from exponential moving averages of
    the per-code assignment counts and of the summed assigned inputs.
    """

    def __init__(
        self,
        emb_dim: int,
        emb_num: int,
        decay: float = 0.1,
        epsilon: float = 1e-5,
    ):
        """Store hyper-parameters and create the moving-average trackers.

        Args:
            emb_dim: Dimensionality of each codebook vector; must equal the
                trailing dimension of the inputs.
            emb_num: Number of codebook vectors.
            decay: Decay passed to both ``hk.ExponentialMovingAverage``
                trackers. NOTE(review): EMA codebooks commonly use a value
                close to 1 (e.g. 0.99) -- confirm 0.1 is intended.
            epsilon: Smoothing constant applied to the cluster sizes so that
                unused codes do not cause a division by zero.
        """
        super().__init__()
        self.emb_dim = emb_dim
        self.emb_num = emb_num
        self.decay = decay
        self.epsilon = epsilon

        ema_decay = np.array(self.decay, dtype=np.float32)
        self._ema_cluster_size = hk.ExponentialMovingAverage(
            decay=ema_decay, name="ema_cluster_size")
        self._ema_dw = hk.ExponentialMovingAverage(
            decay=ema_decay, name="ema_dw")

    @property
    def embeddings(self):
        """Codebook of shape ``[emb_dim, emb_num]``, read from Haiku state."""
        initializer = hk.initializers.VarianceScaling(distribution="uniform")
        return hk.get_state(
            "embeddings", [self.emb_dim, self.emb_num], init=initializer)

    @property
    def ema_cluster_size(self):
        """EMA tracker of per-code assignment counts, shape ``[emb_num]``."""
        self._ema_cluster_size.initialize([self.emb_num], jnp.float32)
        return self._ema_cluster_size

    @property
    def ema_dw(self):
        """EMA tracker of summed assigned inputs, shape ``[emb_dim, emb_num]``."""
        self._ema_dw.initialize([self.emb_dim, self.emb_num], jnp.float32)
        return self._ema_dw

    def __call__(self, inputs: jax.Array, is_training: bool) -> Dict[str, jax.Array]:
        """Quantise ``inputs`` to their nearest codebook vectors.

        Args:
            inputs: Array whose trailing dimension equals ``emb_dim``.
            is_training: Python bool; when True the EMA statistics and the
                codebook state are updated.

        Returns:
            Dict with keys ``"quantized"`` (straight-through quantised inputs),
            ``"q_latent_loss"``, ``"e_latent_loss"``, ``"perplexity"``,
            ``"encodings"`` (one-hot, ``[N, emb_num]``), ``"encoding_indices"``
            (shape ``inputs.shape[:-1]``) and ``"distances"``
            (``[N, emb_num]``), where ``N`` is the number of input vectors
            after flattening the leading dimensions.
        """
        # Collapse all leading dimensions: one row per vector to quantise.
        flat_inputs = jnp.reshape(inputs, [-1, self.emb_dim])
        embeddings = self.embeddings

        # Squared Euclidean distance, expanded as |x|^2 - 2 x.e + |e|^2.
        distances = (
            jnp.sum(jnp.square(flat_inputs), 1, keepdims=True) -
            2 * jnp.matmul(flat_inputs, embeddings) +
            jnp.sum(jnp.square(embeddings), 0, keepdims=True)
        )

        flat_indices = jnp.argmin(distances, 1)
        encodings = jax.nn.one_hot(flat_indices, self.emb_num, dtype=distances.dtype)

        encoding_indices = jnp.reshape(flat_indices, inputs.shape[:-1])
        quantized = self.quantize(encoding_indices)
        # loss
        q_latent_loss = jnp.mean(jnp.square(quantized - jax.lax.stop_gradient(inputs)))
        e_latent_loss = jnp.mean(jnp.square(jax.lax.stop_gradient(quantized) - inputs))

        if is_training:
            cluster_size = jnp.sum(encodings, axis=0)
            updated_ema_cluster_size = self.ema_cluster_size(cluster_size)

            dw = jnp.matmul(flat_inputs.T, encodings)
            updated_ema_dw = self.ema_dw(dw)

            # Laplace smoothing of the cluster sizes: add epsilon to each of
            # the emb_num counts and renormalise so the total stays n.
            n = jnp.sum(updated_ema_cluster_size)
            updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
                                        (n + self.emb_num * self.epsilon) * n)

            normalized_updated_ema_w = (
                updated_ema_dw / jnp.reshape(updated_ema_cluster_size, [1, -1]))
            hk.set_state("embeddings", normalized_updated_ema_w)

        # Straight-through estimator: forward value is the quantised vector,
        # gradient flows to `inputs` as if quantisation were the identity.
        quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
        avg_probs = jnp.mean(encodings, 0)
        perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

        return {
            "quantized": quantized,
            "q_latent_loss": q_latent_loss,
            "e_latent_loss": e_latent_loss,
            "perplexity": perplexity,
            "encodings": encodings,
            "encoding_indices": encoding_indices,
            "distances": distances,
        }

    def quantize(self, encoding_indicies: jax.Array):
        """Look up codebook vectors for integer ``encoding_indicies``.

        Returns an array of shape ``encoding_indicies.shape + (emb_dim,)``.
        """
        # Transpose to [emb_num, emb_dim] so codes are indexed along axis 0.
        w = self.embeddings.swapaxes(1, 0)
        w = jax.device_put(w)
        return w[(encoding_indicies,)]


class VQVAE(hk.Module):
    """Encoder -> vector quantiser -> decoder pipeline."""

    def __init__(
        self,
        encoder: hk.Module,
        decoder: hk.Module,
        emb_dim: int,
        emb_num: int,
        decay: float = 0.1,
        epsilon: float = 1e-5,
    ):
        """Store the sub-modules and build the quantiser.

        Args:
            encoder: Module called as ``encoder(batch_dict)``; its output is
                fed straight to the quantiser (presumably a single array with
                trailing dimension ``emb_dim`` -- confirm against callers).
            decoder: Module called as ``decoder(z_quantized)``.
            emb_dim: Forwarded to ``VectorQuantizer``.
            emb_num: Forwarded to ``VectorQuantizer``.
            decay: Forwarded to ``VectorQuantizer``.
            epsilon: Forwarded to ``VectorQuantizer``.
        """
        # hk.Module subclasses must call the base constructor.
        super().__init__()
        self.encoder = encoder
        self.vecquan = VectorQuantizer(emb_dim, emb_num, decay, epsilon)
        self.decoder = decoder

    def __call__(self, batch_dict: Dict[str, jax.Array], is_training: bool) -> Dict[str, jax.Array]:
        """Encode, quantise and reconstruct.

        Args:
            batch_dict: Input forwarded to the encoder.
            is_training: Forwarded to the quantiser.

        Returns:
            Dict with keys ``"q_latent_loss"``, ``"e_latent_loss"``,
            ``"z_quantized"`` and ``"recon"``.
        """
        z = self.encoder(batch_dict)
        # The quantiser returns a dict, so read its entries by key; unpacking
        # it like a tuple would yield the key strings rather than the arrays.
        vq_out = self.vecquan(z, is_training)
        z_quantized = vq_out["quantized"]
        recon = self.decoder(z_quantized)
        return {
            "q_latent_loss": vq_out["q_latent_loss"],
            "e_latent_loss": vq_out["e_latent_loss"],
            "z_quantized": z_quantized,
            "recon": recon,
        }
